import argparse
from collections import OrderedDict

import mmcv
import torch

# Residual-block counts for each of the four ResNet stages, keyed by depth.
# ``convert`` maps entry i (1-based) to torch ``layer{i}`` and to Detectron
# ``res{i + 1}``.
arch_settings = {
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
}


def convert_bn(blobs, state_dict, caffe_name, torch_name, converted_names):
    """Copy a Detectron affine-channel layer into BatchNorm-style entries.

    Detectron replaces BN with an affine channel layer that stores only a
    shift (``<caffe_name>_b``) and a scale (``<caffe_name>_s``). These become
    ``<torch_name>.bias`` and ``<torch_name>.weight``. The running statistics
    are filled with zeros (mean) and ones (variance) of the same size.
    NOTE(review): this is presumably so that BN in eval mode reduces to the
    affine transform; confirm against the target model's norm config.

    Args:
        blobs: mapping from Caffe2 blob name to numpy array.
        state_dict: output mapping; four entries are inserted, in the order
            bias, weight, running_mean, running_var.
        caffe_name: blob-name prefix in ``blobs``.
        torch_name: parameter-name prefix in ``state_dict``.
        converted_names: set that records which blob names were consumed.
    """
    shift_key = caffe_name + "_b"
    scale_key = caffe_name + "_s"

    state_dict[torch_name + ".bias"] = torch.from_numpy(blobs[shift_key])
    scale = torch.from_numpy(blobs[scale_key])
    state_dict[torch_name + ".weight"] = scale

    stat_shape = scale.size()
    state_dict[torch_name + ".running_mean"] = torch.zeros(stat_shape)
    state_dict[torch_name + ".running_var"] = torch.ones(stat_shape)

    converted_names.update((shift_key, scale_key))


def convert_conv_fc(blobs, state_dict, caffe_name, torch_name, converted_names):
    """Copy a conv / fully-connected layer's weight (and optional bias).

    The weight blob ``<caffe_name>_w`` is required. The bias blob
    ``<caffe_name>_b`` is copied only when it is present in ``blobs``.

    Args:
        blobs: mapping from Caffe2 blob name to numpy array.
        state_dict: output mapping; receives ``<torch_name>.weight`` and,
            if available, ``<torch_name>.bias``.
        caffe_name: blob-name prefix in ``blobs``.
        torch_name: parameter-name prefix in ``state_dict``.
        converted_names: set that records which blob names were consumed.
    """
    weight_key = f"{caffe_name}_w"
    state_dict[f"{torch_name}.weight"] = torch.from_numpy(blobs[weight_key])
    converted_names.add(weight_key)

    bias_key = f"{caffe_name}_b"
    if bias_key not in blobs:
        # Bias-free layer (e.g. a conv followed by an affine channel layer).
        return
    state_dict[f"{torch_name}.bias"] = torch.from_numpy(blobs[bias_key])
    converted_names.add(bias_key)


def convert(src, dst, depth):
    """Convert keys in detectron pretrained ResNet models to pytorch style.

    Detectron stage ``res{i + 1}`` is mapped to torch ``layer{i}``. The stem
    ``conv1`` / ``res_conv1_bn`` becomes ``conv1`` / ``bn1``. Any blob that is
    not consumed is reported on stdout, and the result is saved as
    ``{"state_dict": ...}``.

    Args:
        src: path of the Detectron checkpoint, read with ``mmcv.load``.
        dst: path the converted checkpoint is written to with ``torch.save``.
        depth: ResNet depth; must be a key of ``arch_settings``.

    Raises:
        ValueError: if ``depth`` is not a key of ``arch_settings``.
    """
    if depth not in arch_settings:
        raise ValueError("Only support ResNet-50 and ResNet-101 currently")
    block_nums = arch_settings[depth]

    # NOTE(security): Detectron checkpoints are pickles (hence the latin1
    # encoding); unpickling can execute arbitrary code, so only convert
    # files from trusted sources.
    caffe_model = mmcv.load(src, encoding="latin1")
    # Some checkpoints wrap the tensors in a top-level "blobs" entry.
    blobs = caffe_model["blobs"] if "blobs" in caffe_model else caffe_model

    state_dict = OrderedDict()
    converted_names = set()

    # Stem.
    convert_conv_fc(blobs, state_dict, "conv1", "conv1", converted_names)
    convert_bn(blobs, state_dict, "res_conv1_bn", "bn1", converted_names)

    # Residual stages: Detectron res2..res5 -> torch layer1..layer4.
    for stage_idx, num_blocks in enumerate(block_nums, start=1):
        for block_idx in range(num_blocks):
            caffe_prefix = f"res{stage_idx + 1}_{block_idx}"
            torch_prefix = f"layer{stage_idx}.{block_idx}"
            if block_idx == 0:
                # Only the first block of each stage has a projection
                # shortcut (branch1), stored as downsample.0 / downsample.1.
                convert_conv_fc(
                    blobs,
                    state_dict,
                    f"{caffe_prefix}_branch1",
                    f"{torch_prefix}.downsample.0",
                    converted_names,
                )
                convert_bn(
                    blobs,
                    state_dict,
                    f"{caffe_prefix}_branch1_bn",
                    f"{torch_prefix}.downsample.1",
                    converted_names,
                )
            # Main branch: branch2{a,b,c} -> conv{1,2,3} / bn{1,2,3}.
            for conv_idx, letter in enumerate("abc", start=1):
                convert_conv_fc(
                    blobs,
                    state_dict,
                    f"{caffe_prefix}_branch2{letter}",
                    f"{torch_prefix}.conv{conv_idx}",
                    converted_names,
                )
                convert_bn(
                    blobs,
                    state_dict,
                    f"{caffe_prefix}_branch2{letter}_bn",
                    f"{torch_prefix}.bn{conv_idx}",
                    converted_names,
                )

    # Report every source blob that was not mapped to a torch parameter.
    for key in blobs:
        if key not in converted_names:
            print(f"Not Convert: {key}")

    torch.save({"state_dict": state_dict}, dst)


def main():
    """Command-line entry point: ``<script> SRC DST DEPTH``.

    Parses the three positional arguments and forwards them to ``convert``.
    """
    cli = argparse.ArgumentParser(description="Convert model keys")
    positional_args = (
        ("src", {"help": "src detectron model path"}),
        ("dst", {"help": "save path"}),
        ("depth", {"type": int, "help": "ResNet model depth"}),
    )
    for arg_name, options in positional_args:
        cli.add_argument(arg_name, **options)
    opts = cli.parse_args()
    convert(opts.src, opts.dst, opts.depth)


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
